{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import time\n",
    "searchPath=os.path.abspath('..')\n",
    "sys.path.append(searchPath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import load_iris\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from knn.knn_base import KNN\n",
    "from knn.knn_kdtree import KNNKdTree\n",
    "from utils.data_generater import random_points\n",
    "from utils.plot import plot_knn_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getData(number):\n",
    "    \"\"\"Build a benchmark data set: points from random_points plus 0/1 labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    number : int\n",
    "        Count passed to random_points(2, number).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of np.ndarray\n",
    "        (data, label). Labels are positional: the first number // 2 entries\n",
    "        are 0 and every remaining entry is 1.\n",
    "    \"\"\"\n",
    "    data = random_points(2, number)\n",
    "    zeroCount = number // 2\n",
    "    # An odd number leaves one point over after halving; label it 1 so the\n",
    "    # label array is never shorter than the requested point count.\n",
    "    label = [0] * zeroCount + [1] * (number - zeroCount)\n",
    "    return np.array(data), np.array(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (points, labels) pair per sample size, generated in ascending size order.\n",
    "sampleSizes = (30, 500, 1000, 2000, 5000, 10000, 50000, 400000)\n",
    "generated = [getData(size) for size in sampleSizes]\n",
    "dataList = [points for points, _ in generated]\n",
    "labelList = [targets for _, targets in generated]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifiers under comparison, keyed by the name shown in the timing output.\n",
    "models = dict(knn=KNN(), kdtree=KNNKdTree())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model = knn, dataNum = 30, takeTime = 0.00391\n",
      "model = kdtree, dataNum = 30, takeTime = 0.00417\n",
      "model = knn, dataNum = 500, takeTime = 0.03806\n",
      "model = kdtree, dataNum = 500, takeTime = 0.00856\n",
      "model = knn, dataNum = 1000, takeTime = 0.05203\n",
      "model = kdtree, dataNum = 1000, takeTime = 0.01386\n",
      "model = knn, dataNum = 2000, takeTime = 0.1387\n",
      "model = kdtree, dataNum = 2000, takeTime = 0.02863\n",
      "model = knn, dataNum = 5000, takeTime = 0.28177\n",
      "model = kdtree, dataNum = 5000, takeTime = 0.07277\n",
      "model = knn, dataNum = 10000, takeTime = 0.47404\n",
      "model = kdtree, dataNum = 10000, takeTime = 0.16433\n",
      "model = knn, dataNum = 50000, takeTime = 2.0887\n",
      "model = kdtree, dataNum = 50000, takeTime = 0.93545\n"
     ]
    }
   ],
   "source": [
    "# Time one fit plus five single-point queries for every model on every data set.\n",
    "# time.perf_counter() is used instead of time.time(): it is a monotonic,\n",
    "# high-resolution clock, so system clock adjustments cannot skew the result.\n",
    "for data, label in zip(dataList, labelList):\n",
    "    for name, model in models.items():\n",
    "        startTime = time.perf_counter()\n",
    "        model.fit(data, label)\n",
    "        for _ in range(5):\n",
    "            model.predict([0.3, 0.2])\n",
    "        # Stop the clock before formatting so printing is not part of the measurement.\n",
    "        elapsed = time.perf_counter() - startTime\n",
    "        print(\"model = %s, dataNum = %s, takeTime = %s\"%(name, len(data), round(elapsed, 5)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a fresh KNNKdTree on the first data set and plot its prediction for one query point.\n",
    "point = (0.3, 0.2)\n",
    "sampleData, sampleLabel = dataList[0], labelList[0]\n",
    "model = KNNKdTree()\n",
    "model.fit(sampleData, sampleLabel)\n",
    "plot_knn_predict(model, sampleData, sampleLabel, point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
